Inference in Variational Autoencoders with Different Monte Carlo Sample Sizes
Draft
Please do not share or link.
In a previous post, I demonstrated how to leverage Keras' modular design to implement variational inference in a way that makes it easy to tweak hyperparameters, adapt to related models, and extend to more sophisticated methods from current research.
eps = Input(shape=(mc_samples, latent_dim))
Everything else remains exactly the same. The Multiply layer automatically broadcasts eps, which has shape (batch_size, mc_samples, latent_dim), against sigma, which has shape (batch_size, latent_dim), and produces an output of shape (batch_size, mc_samples, latent_dim). Since the subsequent layers do not operate on the mc_samples axis, it is propagated through to the final output.
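The shape arithmetic can be checked outside of Keras. The following is a minimal NumPy sketch, not the model code itself, and the sizes are arbitrary values chosen for illustration. One difference worth flagging: plain NumPy aligns trailing axes when broadcasting, so mu and sigma need an explicit length-1 axis at position 1 before they can be combined with eps, a step the Keras merge layers are described above as handling automatically.

```python
import numpy as np

# Illustrative sizes only (assumptions, not values from the post).
batch_size, mc_samples, latent_dim = 4, 3, 2

rng = np.random.default_rng(0)
mu = rng.normal(size=(batch_size, latent_dim))
sigma = np.exp(0.5 * rng.normal(size=(batch_size, latent_dim)))
eps = rng.normal(size=(batch_size, mc_samples, latent_dim))

# Insert a length-1 axis so that (batch_size, 1, latent_dim) broadcasts
# against (batch_size, mc_samples, latent_dim).
z = mu[:, np.newaxis, :] + sigma[:, np.newaxis, :] * eps

print(z.shape)
```

Each data point in the batch ends up with mc_samples latent draws that share the same mu and sigma but use different noise.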
Reparameterization with simple location-scale transformation using Keras merge layers.
We expand the targets to a 3D array with np.expand_dims(x_train, axis=1), giving them shape (batch_size, 1, original_dim), so that the loss function can broadcast them against the output of shape (batch_size, mc_samples, original_dim). It is important to distinguish between the log likelihood of the mean over outputs and the mean of the log likelihood over outputs. Since we require the expected log likelihood, we are interested in the latter.
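To see why the order of the mean and the logarithm matters, here is a small NumPy sketch under stated assumptions: the target and the Monte Carlo reconstructions are made-up numbers, and bernoulli_ll is a hypothetical helper rather than part of the model code. Because the Bernoulli log likelihood is concave in the predicted probabilities, Jensen's inequality implies that the log likelihood of the mean output is never smaller than the mean of the log likelihoods, so the two quantities generally differ.

```python
import numpy as np


def bernoulli_ll(x, p):
    """Bernoulli log likelihood, summed over the last axis (hypothetical helper)."""
    return np.sum(x * np.log(p) + (1 - x) * np.log(1 - p), axis=-1)


# One data point with three "pixels" (made-up values).
x = np.array([1.0, 0.0, 1.0])

# Two Monte Carlo reconstructions of that data point,
# shape (mc_samples, original_dim) (made-up values).
p = np.array([[0.9, 0.2, 0.6],
              [0.5, 0.4, 0.8]])

# Mean of the log likelihood: the Monte Carlo estimate of the expected
# log likelihood, which is the quantity we want.
mean_of_ll = bernoulli_ll(x, p).mean()

# Log likelihood of the mean output: a different quantity.
ll_of_mean = bernoulli_ll(x, p.mean(axis=0))

print(mean_of_ll, ll_of_mean)
```

Broadcasting the targets against the full (batch_size, mc_samples, original_dim) output and averaging the per-sample losses afterwards corresponds to the first quantity.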
eps_train = np.random.randn(len(x_train), mc_samples, latent_dim)
eps_test = np.random.randn(len(x_test), mc_samples, latent_dim)

vae.fit(
    [x_train, eps_train],
    np.expand_dims(x_train, axis=1),
    shuffle=True,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(
        [x_test, eps_test],
        np.expand_dims(x_test, axis=1)
    )
)
For every data point, there are mc_samples reconstructions.
recons = vae.predict([x_test[:1], eps_test[:1]]).squeeze()

plt.figure(figsize=(10, 10))
plt.imshow(np.block(list(map(list, recons.reshape(5, 5, 28, 28)))),
           cmap='gray')
plt.show()
plot here
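The np.block call in the plotting snippet tiles a grid of images into one 2D array. Below is a minimal sketch of that step with random data standing in for the model's reconstructions. The value mc_samples = 25 is an assumption inferred from the 5 x 5 reshape and is not stated explicitly in the post.

```python
import numpy as np

mc_samples = 25  # assumed from the 5 x 5 grid in the reshape
rng = np.random.default_rng(0)

# Stand-in for the squeezed output of vae.predict: (mc_samples, 784).
recons = rng.random((mc_samples, 28 * 28))

grid = recons.reshape(5, 5, 28, 28)

# map(list, grid) turns each grid row into a list of five 28x28 images.
# np.block joins the inner lists along the last axis (columns) and the
# outer list along the second-to-last axis (rows).
tiled = np.block(list(map(list, grid)))

print(tiled.shape)
```

The image at grid position (i, j) occupies rows 28*i to 28*(i+1) and columns 28*j to 28*(j+1) of the tiled array.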
Appendix
Below you can find:
- The accompanying Jupyter Notebook used to generate the diagrams and plots in this post.
- The above snippets combined in a single executable Python file:
vae/variational_autoencoder_mc_samples.py (Source)
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras import backend as K
from keras.layers import Input, Dense, Lambda, Layer, Add, Multiply
from keras.models import Model, Sequential
from keras.datasets import mnist

batch_size = 100
original_dim = 784
latent_dim = 2
intermediate_dim = 256
epochs = 50
epsilon_std = 1.0


def nll(y_true, y_pred):
    """ Bernoulli negative log likelihood. """

    # keras.losses.binary_crossentropy gives the mean
    # over the last axis. We require the sum.
    return K.sum(K.binary_crossentropy(y_true, y_pred), axis=-1)


class KLDivergenceLayer(Layer):

    """ Identity transform layer that adds KL divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):

        mu, log_var = inputs

        kl_batch = - .5 * K.sum(1 + log_var -
                                K.square(mu) -
                                K.exp(log_var), axis=-1)

        self.add_loss(K.mean(kl_batch), inputs=inputs)

        return inputs


x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)

z_mu = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

z_mu, z_log_var = KLDivergenceLayer()([z_mu, z_log_var])
z_sigma = Lambda(lambda t: K.exp(.5*t))(z_log_var)

eps = Input(tensor=K.random_normal(shape=(K.shape(x)[0], latent_dim)))
z_eps = Multiply()([z_sigma, eps])
z = Add()([z_mu, z_eps])

decoder = Sequential([
    Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
    Dense(original_dim, activation='sigmoid')
])

x_mean = decoder(z)

vae = Model(inputs=[x, eps], outputs=x_mean)
vae.compile(optimizer='rmsprop', loss=nll)

# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, original_dim) / 255.
x_test = x_test.reshape(-1, original_dim) / 255.

vae.fit(x_train,
        x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test))

encoder = Model(x, z_mu)

# display a 2D plot of the digit classes in the latent space
z_test = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(z_test[:, 0], z_test[:, 1], c=y_test,
            alpha=.4, s=3**2, cmap='viridis')
plt.colorbar()
plt.show()

# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28

# linearly spaced coordinates on the unit square were transformed
# through the inverse CDF (ppf) of the Gaussian to produce values
# of the latent variables z, since the prior of the latent space
# is Gaussian
u_grid = np.dstack(np.meshgrid(np.linspace(0.05, 0.95, n),
                               np.linspace(0.05, 0.95, n)))
z_grid = norm.ppf(u_grid)
x_decoded = decoder.predict(z_grid.reshape(n*n, 2))
x_decoded = x_decoded.reshape(n, n, digit_size, digit_size)

plt.figure(figsize=(10, 10))
plt.imshow(np.block(list(map(list, x_decoded))), cmap='gray')
plt.show()
